import time
import datetime
import csv
import platform
import glob
import os
import pathlib
from threading import Lock, Thread
from queue import Queue, Empty
from quant.utils import Timer2, catch_exception, abstract_method, get_logs_path


class _SavingTime:
    """Tracks rolling file-rotation deadlines, `minutes_per_file` apart.

    The first deadline is aligned to a minute boundary (or to midnight plus
    the interval, for day-or-longer intervals); `has_saved` advances the
    deadline past a given timestamp in whole-interval steps.
    """

    def __init__(self, minutes_per_file):
        self.minutes_per_file = minutes_per_file
        self._next_time = None  # epoch seconds of the next rotation deadline

    def saving_time(self, ts):
        """Return the next rotation deadline, initializing it lazily from `ts`."""
        if self._next_time is None:
            self._next_time = self._first_saving_time(ts)
        return self._next_time

    def last_save_time(self, ts):
        """Return the start of the current window (one interval before the deadline)."""
        return self.saving_time(ts) - self.minutes_per_file * 60

    def has_saved(self, ts):
        """Advance the deadline in whole intervals until it is >= `ts`."""
        step = 60 * self.minutes_per_file
        while ts > self._next_time:
            self._next_time += step

    def _first_saving_time(self, _ts):
        # Day-or-longer intervals: anchor at local midnight plus the interval.
        if self.minutes_per_file >= 24 * 60:
            today = datetime.datetime.fromtimestamp(_ts)
            midnight = datetime.datetime(today.year, today.month, today.day)
            return (midnight + datetime.timedelta(minutes=self.minutes_per_file)).timestamp()

        # Sub-day intervals: floor to the minute, then step to the next
        # multiple of the interval within the hour.
        moment = datetime.datetime.fromtimestamp(int(_ts) // 60 * 60)
        minute = int(moment.strftime('%M'))
        boundary = (minute // self.minutes_per_file + 1) * self.minutes_per_file
        return (moment + datetime.timedelta(minutes=boundary - minute)).timestamp()


class LineLog:
    """Thread-backed CSV line logger with time-based file rotation.

    Callers queue rows via log_line(); a single daemon thread writes them,
    rotating to a new CSV every `file_minutes` minutes and pruning files
    older than `history_hours` hours (checked hourly).
    """

    _is_windows = platform.system() == 'Windows'
    _Save = '_Save'  # sentinel queued by save() to request a flush

    def __init__(self, name, fields, path=None, file_minutes=60, history_hours=24):
        # Default log directory: <this file>/../../logs
        if path is None:
            path = pathlib.Path(__file__).parent.parent.parent.joinpath('logs')
            path = path.as_posix()

        self.name, self.path = name, path
        self.keep_history_hours = history_hours

        self._saving_time = _SavingTime(file_minutes)
        self._fields = ['time'] + fields  # CSV header; a 'time' column is prepended

        self._f = None    # current file handle, opened lazily on first log
        self._csv = None  # csv.writer over self._f

        self._queue = Queue()
        self._lock = Lock()
        self._start()
        self._latest = None  # timestamp of the most recently logged line
        Timer2(self._del_outdated, 3600).start()  # hourly cleanup of old files

    def log_line(self, line, ts=None):
        """Queue `line` (a list of values) for writing; `ts` defaults to now.

        NOTE(review): the file is opened here in the caller's thread, while
        the writer thread also reassigns self._f in _check_time — confirm
        this handoff is race-free in practice.
        """
        if ts is None:
            ts = time.time()

        if self._f is None:
            self._f = self._new_f(ts)
            self._csv = self._new_csv(self._f)

        self._latest = ts
        self._queue.put((ts, line))

    def save(self):
        """Request an asynchronous flush of the current file."""
        self._queue.put((None, self._Save))

    def _start(self):
        Thread(target=self._run, daemon=True).start()

    def _run(self):
        # Writer loop: drain the queue, rotating the file when needed.
        while True:
            ts, line = self._queue.get()
            if line == self._Save:
                if self._f is not None:
                    self._f.flush()
                continue

            try:
                with self._lock:
                    self._check_time(ts)
                    line.insert(0, self._time(ts))  # mutates the caller's list
                    self._csv.writerow(line)
            except:
                catch_exception()

    def _format_file_name(self, ts):
        # e.g. <path>/<name>-03.15-14.00.csv — note: no year in the name.
        dt = datetime.datetime.fromtimestamp(ts)
        index = dt.strftime('%m.%d-%H.%M')
        file = '{}/{}-{}.csv'.format(self.path, self.name, index)
        return file

    def _new_f(self, ts):
        # The file is named after the *start* of the current rotation window.
        file_name = self._format_file_name(self._saving_time.last_save_time(ts))

        if self._is_windows:
            # newline='' avoids blank rows from the csv module on Windows.
            f = open(file_name, 'w', newline='')
        else:
            f = open(file_name, 'w')
        return f

    def _new_csv(self, f):
        new_csv = csv.writer(f)
        new_csv.writerow(self._fields)  # header row
        return new_csv

    def _check_time(self, ts):
        # Rotate to a new file once ts passes the current window's deadline.
        if ts > self._saving_time.saving_time(ts):
            self._f.close()
            self._saving_time.has_saved(ts)
            self._f = self._new_f(ts)
            self._csv = self._new_csv(self._f)

    def _time(self, ts):
        # Format 'HH:MM:SS|mmm'; milliseconds come from the fraction of ts.
        dt = datetime.datetime.fromtimestamp(ts)
        now = dt.strftime('%H:%M:%S')
        ms = '{:.3f}'.format(ts % 1)

        t = f'{now}|{ms[2:]}'
        return t

    def _del_outdated(self):
        """Delete CSV files older than keep_history_hours.

        File names carry no year, so two heuristics compensate: after a year
        change, December files are dropped; when the cutoff falls within the
        current month, files from other months are dropped.

        NOTE(review): the glob pattern '<name>*.csv' also matches other
        loggers whose names share this prefix — verify names are distinct.
        """
        if self._latest is None:
            return
        now_dt = datetime.datetime.fromtimestamp(self._latest)
        cut_dt = now_dt - datetime.timedelta(hours=self.keep_history_hours)

        cut_date_time = cut_dt.strftime('%m.%d-%H.%M')

        del_files = []
        all_files = glob.glob(self.path + '/{}*.csv'.format(self.name))

        year_leap = now_dt.year != cut_dt.year
        within_month = now_dt.month == cut_dt.month
        month = now_dt.strftime('%m')

        for file in all_files:
            # Last two '-'-separated segments are 'MM.DD' and 'HH.MM.csv'.
            splits = file.split('-')
            data_time_i = '{}-{}'.format(splits[-2], splits[-1].replace('.csv', ''))
            if data_time_i < cut_date_time:  # lexicographic works: fields are zero-padded
                del_files.append(file)
            elif year_leap and data_time_i.startswith('12'):
                del_files.append(file)
            elif within_month and (not data_time_i.startswith(str(month))):
                del_files.append(file)

        for file in del_files:
            try:
                os.remove(file)
            except PermissionError:
                pass  # e.g. file still open elsewhere; retried next hour


class LineLog2:
    """Field-name-addressable wrapper over LineLog.

    Keeps a sticky ('ache') value per column that is re-emitted on every
    line until cleared, so sparse updates still produce complete rows.
    """

    def __init__(self, name, fields, path=None, file_minutes=60, history_hours=24):
        if path is None:
            path = get_logs_path()
        self._line_log = LineLog(name, fields, path, file_minutes, history_hours)
        self._name_to_col = {field: col for col, field in enumerate(fields)}
        self._col_to_name = {col: field for field, col in self._name_to_col.items()}
        self._ache_value = {}  # field name -> sticky value
        self._len = len(fields)

    def _base_line(self):
        # Blank row pre-filled with the sticky values.
        row = [' '] * self._len
        for field, value in self._ache_value.items():
            row[self._name_to_col[field]] = value
        return row

    def log_line(self, line, ts=None):
        """Pass a fully-formed row straight through to the underlying log."""
        self._line_log.log_line(line, ts)

    def log_line_with_ache(self, line, ts=None):  # do not change ache even if different
        row = self._base_line()
        for col, value in enumerate(line):
            if value not in ('', ' '):  # blanks keep the sticky value
                row[col] = value
        self._line_log.log_line(row, ts)

    def log_value(self, name, value, ts=None):
        """Log one named value, padding the rest of the row with sticky values."""
        row = self._base_line()
        row[self._name_to_col[name]] = value
        self._line_log.log_line(row, ts)

    def log_value_dict(self, dic, ts=None):
        """Log several named values at once, padded with sticky values."""
        row = self._base_line()
        for field, value in dic.items():
            row[self._name_to_col[field]] = value
        self._line_log.log_line(row, ts)

    def log_ache(self, name, value):
        """Set one sticky value."""
        self._ache_value[name] = value

    def log_ache_dic(self, dic):
        """Set several sticky values."""
        self._ache_value.update(dic)

    def flush_ache(self):
        """Emit a row containing only the sticky values; return that row."""
        row = self._base_line()
        self._line_log.log_line(row)
        return row

    def clear_ache(self):
        """Forget all sticky values."""
        self._ache_value.clear()

    def save(self):
        self._line_log.save()


class RawLogger:
    """Interface for persisting and replaying raw market events.

    Subclasses store (event, exchange, symbol, frequency, recv_time, raw)
    records and yield them back in recv_time order via iter_raw().
    """

    def __init__(self, name):
        self.name = name

    @abstract_method
    def log_raw(self, event, exchange, symbol, frequency, recv_time, raw):
        pass

    @abstract_method
    def iter_raw(self):
        event, exchange, symbol, frequency, recv_time, raw = '123456'
        yield event, exchange, symbol, frequency, recv_time, raw

    @abstract_method
    def drop_raw(self):
        """Delete stored data; deletion during writing is supported
        (see FileRawLogger: file.close())."""
        pass

    @abstract_method
    def save(self):
        pass

    def cut_raw(self, cut_name, begin: datetime.datetime, end: datetime.datetime):
        """Copy records with begin <= recv_time <= end into a fresh logger
        of the same type named `cut_name`, then save it.

        Relies on iter_raw() yielding records in recv_time order: iteration
        stops at the first record past `end`.
        """
        begin_ts = begin.timestamp()
        end_ts = end.timestamp()

        target = type(self)(cut_name)
        target.drop_raw()

        for record in self.iter_raw():
            recv = record[4]
            if recv > end_ts:
                break
            if recv >= begin_ts:
                target.log_raw(*record)

        target.save()


class MongoRawLogger(RawLogger):
    """RawLogger backed by a local MongoDB 'raw_data' database, one
    collection per logger name. Writes go through a queue consumed by a
    daemon thread so log_raw() never blocks on the database."""

    def __init__(self, name):
        import pymongo
        super().__init__(name)
        self._start_server()

        # Create the collection handle BEFORE starting the writer thread:
        # previously the thread could consume an early log_raw() while
        # self.collection did not exist yet (AttributeError).
        self.db = pymongo.MongoClient('localhost', 27017)['raw_data']
        self.collection = self.db[name]

        self._queue = Queue()
        self._thread = Thread(target=self._run, daemon=True)
        self._thread.start()

    def log_raw(self, event, exchange, symbol, frequency, recv_time, raw):
        """Queue one record; the daemon thread performs the insert."""
        dic = {
            'event': event,
            'exchange': exchange,
            'symbol': symbol,
            'frequency': frequency,
            'recv_time': recv_time,
            'raw': raw,
        }
        self._queue.put(dic)

    def iter_raw(self):
        """Yield stored records in natural (insertion) order."""
        for doc in self.collection.find():
            yield doc['event'], doc['exchange'], doc['symbol'], doc['frequency'], doc['recv_time'], doc['raw']

    def drop_raw(self):
        # Collection.remove() was removed in pymongo 3/4; delete_many({})
        # clears every document and is the supported replacement.
        self.collection.delete_many({})

    def _start_server(self):
        import psutil
        import pathlib

        # Best-effort: launch mongod only if no process with that name runs.
        programs = [p.as_dict(attrs=['name'])['name'] for p in psutil.process_iter()]
        if 'mongod' in programs:
            return

        path = pathlib.Path(__file__)
        config_file = '{}{}mongo_config.txt'.format(str(path.parent), os.sep)
        os.system('mongod -f {}'.format(config_file))

    def _run(self):
        # Writer loop: one insert per queued record.
        while True:
            dic = self._queue.get()
            self.collection.insert_one(dic)


class FileRawLogger(RawLogger):
    """RawLogger that appends records to a single '|||'-separated text file.

    Writes are queued and flushed by a daemon thread (every ~100 lines);
    iter_raw() re-parses the file line by line.
    """

    Split = '|||'  # field separator inside each line

    def __init__(self, name, log_path=None):
        super(FileRawLogger, self).__init__(name)

        # Default storage directory: <this file>/../../../raw
        if log_path is None:
            path = pathlib.Path(__file__)
            log_path = path.parent.parent.parent.parent.joinpath('raw')
        else:
            log_path = pathlib.Path(log_path)

        # makedirs is race-free and creates missing parents, unlike the
        # previous exists()+mkdir() pair.
        os.makedirs(log_path, exist_ok=True)

        file_name = log_path.joinpath('{}.txt'.format(name))
        self._file_name = file_name.as_posix()
        self._write_file = None  # opened lazily by the writer thread

        self._queue = Queue()
        Thread(target=self._run, daemon=True).start()

    def log_raw(self, event, exchange, symbol, frequency, recv_time, raw):
        """Queue one record; the daemon thread appends it to the file."""
        self._queue.put([event, exchange, symbol, frequency, recv_time, raw])

    def iter_raw(self):
        """Yield (event, exchange, symbol, frequency, recv_time, raw) tuples.

        Raises FileNotFoundError when nothing has been written yet.
        """
        sep = self.Split
        # 'with' guarantees the handle is closed even when the consumer
        # abandons the generator early (the original leaked the handle).
        with open(self._file_name, 'r') as read_file:
            for text in read_file:
                # maxsplit=5 keeps the raw payload intact even when it
                # happens to contain the separator itself (the unbounded
                # split raised ValueError in that case).
                event, exchange, symbol, frequency, recv_time, raw = text.split(sep, 5)
                recv_time = float(recv_time)
                raw = raw[:-1]  # strip the trailing newline
                yield event, exchange, symbol, frequency, recv_time, raw

    def drop_raw(self):
        """Discard queued and persisted data; the logger stays usable."""
        while not self._queue.empty():
            self._queue.get()

        if self._write_file is not None:
            self._write_file.close()
            self._write_file = None  # _run reopens (recreating the file) on next record

        try:
            os.remove(self._file_name)
        except FileNotFoundError:
            pass

    def save(self):
        """Flush buffered lines to disk."""
        if self._write_file is not None:
            print(self.name, 'saved')
            self._write_file.flush()

    def _run(self):
        # Writer loop: serialize queued records, flushing every ~100 lines.
        count = 0
        while True:
            count += 1
            line = self._queue.get()
            if self._write_file is None:
                self._write_file = open(self._file_name, 'a')

            line[4] = str(line[4])  # recv_time float -> text
            text = self.Split.join(line) + '\n'
            self._write_file.write(text)
            if count > 100:
                self._write_file.flush()
                count = 0


class RawRecorder:
    """Subscribes to raw market events and forwards them to a RawLogger.

    Recording and replaying are mutually exclusive process-wide, tracked by
    the IsRecording / RawReplayer.IsReplaying class flags.
    """

    IsRecording = False  # class-wide flag checked by RawReplayer

    def __init__(self, markets, logger: RawLogger):
        self.markets = markets
        self.logger = logger
        self._started = False

    def start(self):
        """Drop previously recorded data and begin recording.

        Idempotent: a second start() is a no-op. Raises if a replay is in
        progress.
        """
        if self._started:
            return

        # Check the replay flag BEFORE flipping our own flag, so a failed
        # start does not leave IsRecording stuck at True.
        if RawReplayer.IsReplaying:
            err = 'Cannot Record while Replaying!'
            raise Exception(err)
        type(self).IsRecording = True
        # Bug fix: _started was never set, so the duplicate-start guard
        # above was dead code and start() could subscribe twice.
        self._started = True

        self.logger.drop_raw()
        self.markets.info_engine.subscribe_raw(self._on_raw)

    def stop(self):
        """Stop forwarding raw events; the recorder may be started again."""
        self.markets.info_engine.unsubscribe_raw(self._on_raw)
        self._started = False

    def _on_raw(self, event, exchange, symbol, frequency, recv_time, raw):
        self.logger.log_raw(event, exchange, symbol, frequency, recv_time, raw)


class RawReplayer:
    """Feeds recorded raw events back into the markets object.

    start() replays everything in one go; the push_* methods drive the
    shared iterator incrementally for stepwise debugging.
    """

    IsReplaying = False  # class-wide flag checked by RawRecorder

    def __init__(self, markets, logger: RawLogger):
        self.markets = markets
        self.logger = logger
        self.data_iter = logger.iter_raw()  # shared iterator for push_* methods
        self._last_show_ts = 0

    def start(self, show_time=True):
        """Replay every stored record; raises if a recorder is active."""
        type(self).IsReplaying = True
        if RawRecorder.IsRecording:
            err = 'Cannot Replay while Recording!'
            raise Exception(err)

        for record in self.logger.iter_raw():
            self.markets.feed_raw(*record)
            if show_time:
                self._print_time(record[4])

    def push_one(self):
        """Feed exactly one record; StopIteration propagates when exhausted."""
        self.markets.feed_raw(*next(self.data_iter))

    def push_many(self, n):
        """Feed `n` records."""
        for _ in range(n):
            self.push_one()

    def push_interval(self, interval):
        """Push records until markets.now advances by more than `interval` seconds."""
        if self.markets.now is None:
            # Seed the clock with the first record's recv_time (fresh iterator,
            # so the shared data_iter is not consumed).
            first = next(self.logger.iter_raw())
            self.markets.now = first[4]

        anchor = None
        pushed = 0
        while True:
            try:
                self.push_one()
                pushed += 1
            except StopIteration:
                return

            current = self.markets.now
            if anchor is None:
                anchor = current

            if current - anchor > interval:
                return

    def test_push(self, step=1, call=None):
        """Debug helper: push `step` records forever, invoking `call` between batches."""
        while True:
            self.push_many(step)
            if call:
                call()

    def test_push_2(self, interval, call=None):
        """Debug helper: push by time interval forever, invoking `call` between batches."""
        while True:
            self.push_interval(interval)
            if call:
                call()

    def test_push_3(self, step, step_call, loop_call):
        """Debug helper: interactive stepping, pausing for input each loop."""
        while True:
            loop_call()
            for _ in range(step):
                self.push_one()
                step_call()
            input(':')

    def _print_time(self, recv):
        # Print progress at most once per replayed minute.
        if recv - self._last_show_ts > 60:
            self._last_show_ts = recv // 60 * 60
            shown = datetime.datetime.fromtimestamp(recv).strftime('%Y/%m/%d-%H:%M:00')
            print('正在撮合: {}'.format(shown))


class RpcServer:
    """Minimal ZeroMQ REP server dispatching registered functions by name.

    The serve loop runs on a daemon thread; failures are logged and the
    client receives None instead of a result.
    """

    def __init__(self, port=5555):
        self.port = port
        self._functions = {}  # name -> callable

    def start(self):
        """Launch the serve loop on a daemon thread."""
        Thread(target=self._run, daemon=True).start()

    def register(self, name, func):
        """Expose `func` to clients under `name`."""
        self._functions[name] = func

    def _run(self):
        import zmq

        context = zmq.Context()
        sock = context.socket(zmq.REP)
        sock.bind("tcp://*:{}".format(self.port))
        while True:
            name, args, kwargs = sock.recv_pyobj()
            try:
                sock.send_pyobj(self._functions[name](*args, **kwargs))
            except:
                catch_exception()
                sock.send_pyobj(None)  # keep the REQ/REP cycle balanced


class RpcClient:
    """Minimal ZeroMQ REQ client matching RpcServer."""

    _socket = None  # created by start()

    def __init__(self, ip, port):
        self.ip = ip
        self.port = port

    def start(self):
        """Create the REQ socket and connect to the configured endpoint."""
        import zmq
        context = zmq.Context()
        socket = context.socket(getattr(zmq, 'REQ'))
        # Bug fix: the endpoint was hard-coded to "tcp://localhost:5555",
        # silently ignoring the ip/port passed to __init__.
        socket.connect("tcp://{}:{}".format(self.ip, self.port))
        self._socket = socket

    def call(self, name, *args, **kwargs):
        """Invoke remote `name`; returns its result, or None on server error."""
        self._socket.send_pyobj((name, args, kwargs))
        result = self._socket.recv_pyobj()
        return result


class SimulateInfoServer:
    """
    Collects simulation results over RPC and persists them as CSVs under
    logs/sim_result: a profit curve per strategy, a statistics table and a
    deal log. RPC entry points only enqueue; one worker thread applies them.

    add_new_function:
    1.write 2 func, 2.register to rpc, 3.create df/file, 4.save df in save()
    """

    _APPENDIX = '(p)'  # marker appended to names loaded from a previous run

    def __init__(self, name='sim', port=5555):
        from pandas import DataFrame
        self.rpc_server = RpcServer(port)
        self.rpc_server.start()
        self.name = name
        self._queue = Queue()

        path = get_logs_path().joinpath('sim_result')
        if not os.path.isdir(path.as_posix()):
            os.mkdir(path.as_posix())

        # Existing files are reloaded; their columns/rows are marked '(p)'
        # so new results never overwrite previous runs.
        self.profit_file = path.joinpath('{}_profits.csv'.format(self.name)).as_posix()
        self.profit_df = self._read_df(self.profit_file)
        self._rename_df_columns(self.profit_df)

        self.statistics_file = path.joinpath('{}_statistics.csv'.format(self.name)).as_posix()
        self.statistics_df = self._read_df(self.statistics_file)
        self._rename_df_index(self.statistics_df)

        self.deal_file = path.joinpath('{}_deal.csv'.format(self.name)).as_posix()
        self.deal_df = DataFrame()

        Thread(target=self._run, daemon=True).start()
        self.rpc_server.register('write_profits', self.write_profits)
        self.rpc_server.register('write_deal', self.write_deal)
        self.rpc_server.register('write_statistics', self.write_statistics)

    # ---- RPC entry points: enqueue only; the worker applies them ----

    def write_deal(self, time, name, deal):
        self._queue.put(('write_deal', time, name, deal))

    def _write_deal(self, time, name, deal):
        # One row per deal, one column per strategy name.
        index = len(self.deal_df)
        self.deal_df.loc[index, 'time'] = time
        self.deal_df.loc[index, name] = deal

    def write_statistics(self, name, statistics_map):
        self._queue.put(('write_statistics', name, statistics_map))

    def _write_statistics(self, name, statistics_map):
        # One row per strategy name, one column per statistic.
        for k, v in statistics_map.items():
            self.statistics_df.loc[name, k] = v

    def write_profits(self, name, profits):
        self._queue.put(('write_profits', name, profits))

    def _write_profits(self, name, profits):
        # One column per strategy name.
        self.profit_df[name] = profits

    def save(self):  # call after self.join()
        """Persist all three DataFrames to their CSV files."""
        self.profit_df.to_csv(self.profit_file)
        self.statistics_df.to_csv(self.statistics_file)
        self.deal_df.to_csv(self.deal_file, index=False)

    def join(self):
        """Block until every queued write has been applied, then save."""
        self._queue.join()
        self.save()

    def _run(self):
        # Worker loop: dispatch queued commands to the matching '_'-prefixed
        # method; failures are logged and the item is still marked done.
        while True:
            name, *args = self._queue.get()
            func = getattr(self, '_' + name)
            try:
                func(*args)
            except:
                catch_exception()
            self._queue.task_done()

    def _read_df(self, file, **kwargs):
        """Load a CSV into a DataFrame; missing or empty files yield an empty frame."""
        import pandas as pd
        try:
            df = pd.read_csv(file, index_col=0, **kwargs)
        except FileNotFoundError:
            df = pd.DataFrame()
        except pd.errors.EmptyDataError:  # was getattr(pd, 'errors') obfuscation
            df = pd.DataFrame()
        return df

    def _build_renames(self, names):
        """Map every name lacking the '(p)' marker to a unique marked name.

        Deduplicates the logic previously copy-pasted between
        _rename_df_columns and _rename_df_index. Uniqueness is checked
        against the original name set only, matching the prior behavior.
        """
        taken = set(names)
        renames = {}
        for name in names:
            if self._APPENDIX in name:
                continue
            new_name = name + self._APPENDIX
            while new_name in taken:
                new_name += '*'  # keep suffixing until unique
            renames[name] = new_name
        return renames

    def _rename_df_columns(self, df):
        """Mark previously-saved profit columns with the '(p)' appendix."""
        df.rename(columns=self._build_renames(list(df.columns)), inplace=True)

    def _rename_df_index(self, df):
        """Mark previously-saved statistics rows with the '(p)' appendix."""
        df.rename(index=self._build_renames(list(df.index)), inplace=True)


class _SimulateInfoClient:
    """Thin RPC proxy for SimulateInfoServer's write_* endpoints."""

    def __init__(self, port=5555):
        self.client = RpcClient('localhost', port)
        self.client.start()

    def write_profits(self, name, profits):
        self._send('write_profits', name, profits)

    def write_deal(self, time, name, deal):
        self._send('write_deal', time, name, deal)

    def write_statistics(self, name, **kwargs):
        # Statistics are forwarded as a single dict argument.
        self._send('write_statistics', name, kwargs)

    def _send(self, endpoint, *payload):
        self.client.call(endpoint, *payload)


class SimulateInfoCollector:
    """Registers per-account collectors and triggers their collection."""

    def __init__(self, markets):
        self.markets = markets
        self.info_client = _SimulateInfoClient()
        self._collects = []  # one _AccountCollector per registered account

    def add_account(self, name, account):
        """Attach a collector for `account`, reported under `name`."""
        collector = _AccountCollector(name, account, self.markets, self.info_client)
        self._collects.append(collector)

    def collect(self):
        """Ask every registered collector to push its results."""
        for collector in self._collects:
            collector.collect()


class _AccountCollector:
    """Collects per-account simulation results (profit curve, deal volume,
    order placements) and pushes them to the simulate info server."""

    _profit_interval = 20   # seconds between profit samples (simulation timer)
    _profit_count = 2000    # number of points the profit curve is resampled to
    _client = None          # class-wide _SimulateInfoClient, created lazily

    def __init__(self, name, account, markets, info_client):
        from quant.markets import functions

        if type(self)._client is None:
            type(self)._client = _SimulateInfoClient()

        self.name = name
        self.account = account
        self.markets = markets
        # NOTE(review): info_client is stored but collect() uses the
        # class-wide _client instead — confirm which one is intended.
        self.info_client = info_client
        self.timer = self.markets.timer

        self._get_value_factor = functions.get_value_factor
        self._exchange = account.exchange
        self._statistics = {'place_times': 0, 'deal': 0}
        self._profits = []  # profit samples appended by on_timer

        ee = account.info_engine
        ee.subscribe(ee.Place, self.on_place)
        ee.subscribe(ee.Deal, self.on_deal)
        self.timer.register(self.on_timer, self._profit_interval)

    def collect(self):
        """Resample the profit curve and push profits + statistics to the server.

        NOTE(review): self._profits[-1] raises IndexError when no sample has
        been recorded yet — confirm collect() only runs after on_timer fired.
        """
        profits = self._resize_profits(self._profits, self._profit_count)
        self._client.write_profits(self.name, profits)

        self._statistics['profit'] = self._profits[-1]
        self._client.write_statistics(self.name, **self._statistics)

    def on_deal(self, order, amount):
        # Accumulate dealt value, scaled by the per-symbol value factor.
        self._statistics['deal'] += amount * self._get_value_factor(self._exchange, order.symbol)
        # now = datetime.datetime.fromtimestamp(self.markets.now)
        # now = now.strftime('%Y/%m/%d-%H:%M:%S')
        # deal = '{} {}({})'.format(order.side, order.price, amount)
        # self._client.write_deal(now, self.name, deal)

    def on_place(self, order):
        # Count every order placement.
        self._statistics['place_times'] += 1

    def on_timer(self):
        # Periodic profit snapshot of the account's current valuation.
        profit = self.account.evaluate()
        self._profits.append(profit)

    def _resize_profits(self, profits, count):
        """Resample `profits` to exactly `count` points by index scaling.

        NOTE(review): on IndexError the fallback appends the literal 1 to the
        caller's list (mutating it) and retries — looks like a stopgap for
        short or empty curves; confirm this is intended.
        """
        p_count = len(profits)
        factor = (p_count-1) / (count-1)
        try:
            resized = [profits[int(i*factor)] for i in range(count)]
        except IndexError:
            profits += [1]
            resized = [profits[int(i*factor)] for i in range(count)]
        return resized


class ValueUpdater:
    """Periodically refreshes a Saving-backed value map via `update_fn`.

    get_value() blocks until at least one update has been attempted, and
    retries a missed key once after forcing a refresh.
    """

    _quick_mode = False  # when set, skip the initial blocking update

    @classmethod
    def set_quick_mode(cls):
        cls._quick_mode = True

    def __init__(self, name, update_fn, query_intv):
        from quant.utils import Saving, Timer2

        self.update_fn = update_fn
        # Quick mode pretends the first update already happened.
        self._updated = bool(self._quick_mode)

        self._saving = Saving('ValueUpdater({})'.format(name))
        Timer2(self._update, query_intv).start()

    def get_value(self, name):
        """Return the stored value for `name`; raises KeyError after one retry."""
        while not self._updated:
            self._update()
            time.sleep(1)

        for attempt in range(2):
            try:
                return self._saving.get(name)
            except KeyError:
                if attempt == 0:
                    self._update()  # maybe the key appears after a refresh

        err = 'ValueUpdater.get_value({}), KeyError.'.format(name)
        raise KeyError(err)

    def _update(self):
        # Mark updated even on failure so get_value() cannot block forever.
        self._updated = True
        for _ in range(3):
            try:
                self._saving.update(self.update_fn())
                return
            except:
                from quant.utils import catch_exception
                catch_exception()


class __NoneValue:
    """Absorbing placeholder: falsy, never orders against anything, equal
    only to the NoneValue sentinel, and every arithmetic operation yields
    NoneValue again so missing data propagates through calculations."""

    def __repr__(self):
        return 'NoneValue()'

    def __bool__(self):
        return False

    def __eq__(self, other):
        return other is NoneValue

    # Every ordering comparison is False in both directions.
    def _false(self, other):
        return False

    __lt__ = __le__ = __gt__ = __ge__ = _false

    # Every arithmetic/unary/rounding operation absorbs into NoneValue.
    def _absorb(self, *args, **kwargs):
        return NoneValue

    __add__ = __radd__ = __sub__ = __rsub__ = _absorb
    __mul__ = __rmul__ = __truediv__ = __rtruediv__ = _absorb
    __pow__ = __rpow__ = __mod__ = __rmod__ = _absorb
    __abs__ = __pos__ = __neg__ = __round__ = _absorb


class RecentRawCollector:
    """Maintains rolling raw-data recordings: one logger per `bulk_minutes`
    window, dropping the oldest once more than keep_minutes of history exists."""

    def __init__(self, name, markets, logger_cls, bulk_minutes=60, keep_minutes=3*24*60):
        self.name = name
        self.markets = markets
        self.logger_cls = logger_cls
        self.bulk_minutes = bulk_minutes
        self.keep_minutes = keep_minutes

        # Number of bulks to retain; always keep at least the current one.
        self._keep_num = max(int(keep_minutes / bulk_minutes), 1)

        self._bulk_count = 0          # index of the bulk currently recording
        self._history_logger = []     # retained loggers, oldest first
        self._this_recorder = None    # active RawRecorder

    def start(self):
        """Create the first recorder and poll every second for bulk rollover."""
        self._build_recorder()
        Timer2(self._on_timer, 1).start()

    def _cal_bulk_count(self):
        # Index of the current bulk window since the epoch.
        return time.time() // (self.bulk_minutes * 60)

    def _build_recorder(self):
        self._bulk_count = self._cal_bulk_count()
        begin_ts = self._bulk_count * self.bulk_minutes * 60
        # Logger names are stamped in UTC+8.
        tz8 = datetime.timezone(datetime.timedelta(hours=8))
        begin = datetime.datetime.fromtimestamp(begin_ts, tz8)
        logger_name = begin.strftime('{}-%m.%d-%H.%M'.format(self.name))

        # Close out the previous window before opening the next.
        if self._this_recorder is not None:
            self._this_recorder.stop()
            self._this_recorder.logger.save()

        logger = self.logger_cls(logger_name)
        self._history_logger.append(logger)
        self._this_recorder = RawRecorder(self.markets, logger)
        self._this_recorder.start()

        # Evict the oldest loggers beyond the retention window.
        while len(self._history_logger) > self._keep_num:
            dropped = self._history_logger.pop(0)
            dropped.drop_raw()
            print('drop', dropped.name)

    def _on_timer(self):
        if self._cal_bulk_count() != self._bulk_count:
            self._build_recorder()


class InfluxClient:
    """Buffered InfluxDB 1.x client: write() enqueues points, a daemon
    thread flushes them in per-database batches every `delay` seconds."""

    _singulars = {}  # (host, port, user, password, delay) -> shared instance

    @classmethod
    def create_singular(cls, host, port=8086, user_name='johnsquant', password='wo5892132156', delay=1):
        """Return one shared client per connection signature."""
        _id = (host, port, user_name, password, delay)
        if _id in cls._singulars:
            return cls._singulars[_id]

        instance = cls(*_id)
        cls._singulars[_id] = instance
        return instance

    def __init__(self, host, port=8086, user_name='johnsquant', password='wo5892132156', delay=1):
        from influxdb import InfluxDBClient

        self.client = InfluxDBClient(host, port, user_name, password)
        self.delay = delay  # seconds between batch flushes

        self._queue = Queue()
        Thread(target=self._run, daemon=True).start()

    def create_database(self, database):
        self.client.create_database(database)

    def drop_database(self, database):
        self.client.drop_database(database)

    def query(self, sql, database):
        return self.client.query(sql, database=database)

    def set_retention_policy(self, database, keep_hours=24):
        """Ensure the default retention policy keeps `keep_hours` of data
        (0 means keep forever, expressed as '0s'); None leaves it untouched."""
        if keep_hours is None:
            return

        name = 'policy-1'
        keep = '{}h'.format(keep_hours)
        if keep_hours == 0:
            keep = '0s'

        for p in self.client.get_list_retention_policies(database):
            if p['default'] is True:
                # Durations come back like '24h0m0s'; strip the '0m...' tail.
                # NOTE(review): this parsing assumes whole-hour durations — verify.
                dur = p['duration']
                if dur.split('0m')[0] != keep:
                    self.client.drop_retention_policy(name, database)
                    self.client.create_retention_policy(name, keep, '1', database, True)
                return

        self.client.create_retention_policy(name, keep, '1', database, True)

    def write(self, database, measurement, tags, fields, now=None):
        """Queue one point for asynchronous writing; `now` defaults to time.time()."""
        if now is None:
            now = time.time()
        self._queue.put((database, measurement, tags, fields, now))

    def read(self, database, measurement, from_, to):
        """Return a DataFrame of points with from_ <= time < to, or None when empty."""
        import pandas as pd

        dt0 = datetime.datetime.utcfromtimestamp(from_)
        dt1 = datetime.datetime.utcfromtimestamp(to)
        # Bug fix: the two time clauses were concatenated without a space
        # ("...'{dt0}'AND time..."), producing invalid InfluxQL.
        sql = "SELECT * FROM \"{}\" WHERE time >= '{}' AND time < '{}'".format(measurement, dt0, dt1)

        result = self.client.query(sql, database=database)
        series = result.raw['series']
        if len(series) == 0:
            return None

        data = series[0]
        return pd.DataFrame(data['values'], columns=data['columns'])

    def show_tail(self, database, measurement, tail=20):
        """Print the first `tail` rows of a measurement (debug helper)."""
        import pandas as pd
        pd.set_option('display.max_colwidth', 20)
        pd.set_option('display.max_columns', 20)
        pd.set_option('display.width', 1000)

        try:
            result = self.client.query('select * from {} LIMIT {}'.format(measurement, tail), database=database)
        except Exception as err:
            print('show tail err:{}'.format(err))
            print('Null database')
            return

        series = result.raw['series']
        if len(series) == 0:
            print('Null measurement')
            return

        assert len(series) == 1
        data = series[0]

        df = pd.DataFrame(data['values'], columns=data['columns'])
        print(df.tail(tail))

    def show_storage(self):
        """Print per-database shard disk usage (GiB) from the _internal monitor."""
        a = self.query('select sum(diskBytes) / 1024 / 1024 /1024 from _internal."monitor"."shard" where time > now() - 20s group by "database"', None)
        for series in a.raw['series']:
            print(series)

    def join(self):
        """Wait until the queue drains.

        NOTE(review): the trailing sleep only approximates waiting for the
        in-flight batch to finish — confirm that is acceptable to callers.
        """
        while not self._queue.empty():
            time.sleep(0.2)
        time.sleep(0.2)

    def _run(self):
        # Flush loop: batch queued points per database every `delay` seconds.
        while True:
            try:
                db_to_batch = self._read_batch()
                for db, batch in db_to_batch.items():
                    self._write_batch(db, batch)
            except:
                catch_exception()
            time.sleep(self.delay)

    def _write(self, database, measurement, tags, fields, now):
        """Synchronously write a single point (bypasses the queue)."""
        point = {
            'measurement': measurement,
            'tags': tags,
            'fields': fields,
            'time': datetime.datetime.utcfromtimestamp(now),
        }
        self.client.write_points([point], database=database, time_precision='u')

    def _read_batch(self):
        """Drain the queue into {database: [(database, measurement, tags, fields, now), ...]}."""
        db_to_batch = {}
        while not self._queue.empty():
            item = self._queue.get()
            db_to_batch.setdefault(item[0], []).append(item)
        return db_to_batch

    def _write_batch(self, database, batch):
        points = [
            {
                'measurement': measurement,
                'tags': tags,
                'fields': fields,
                'time': datetime.datetime.utcfromtimestamp(now),
            } for _, measurement, tags, fields, now in batch
        ]

        self.client.write_points(points, database=database, time_precision='u')


class InfluxClientUSIncreasing(InfluxClient):
    """InfluxClient variant that guarantees strictly increasing timestamps.

    Each write()'s timestamp is rounded to microseconds; if it does not exceed
    the previously used timestamp, it is bumped one microsecond past it
    (presumably so same-timestamp points cannot collide — confirm against the
    InfluxDB duplicate-point semantics used by this project).
    """

    _log_ts = 0  # last timestamp (seconds, microsecond-rounded) handed to write()

    def write(self, database, measurement, tags, fields, now=None):
        """Write one point, nudging *now* forward so timestamps never repeat."""
        if now is None:
            now = time.time()

        candidate = round(now, 6)
        if candidate <= self._log_ts:
            candidate = round(self._log_ts + 0.000001, 6)
        self._log_ts = candidate
        super().write(database, measurement, tags, fields, candidate)


class Ding:
    """Asynchronous DingTalk notifier.

    send() only enqueues; a daemon thread batches queued messages and posts
    them through DingtalkChatbot, throttled by a _RateLimit built from the
    class-level `rate_limit` description.
    """

    default_key = '2187cf3f3020cfd841af96b580df3e4a3a8832d47b50d8b41e3351d0e8aab184'
    url = 'https://oapi.dingtalk.com/robot/send'
    rate_limit = '1/7'  # at most 1 send per 7 seconds

    @classmethod
    def create_default_ding(cls):
        """Build a bot using the class default token and '[event]' keyword."""
        return Ding(cls.default_key, '[event]')

    def __init__(self, key, keyword):
        # Imported lazily so the module loads even when dingtalkchatbot
        # is not installed and Ding is never constructed.
        from dingtalkchatbot.chatbot import DingtalkChatbot
        self.key = key
        self.keyword = keyword  # prefixed to every outgoing text (robot keyword filter)
        self.url = f'{type(self).url}?access_token={key}'

        self._ding_chat = DingtalkChatbot(self.url)
        self._rate_limit = self._build_rate_limit()

        self._queue = Queue()
        self._thread = Thread(target=self._run, daemon=True)
        self._thread.start()

    def _build_rate_limit(self):
        """Parse the 'times/seconds' class config into a _RateLimit."""
        n_times, n_secs = type(self).rate_limit.split('/')
        return _RateLimit('ding_bot', int(n_times), int(n_secs))

    def send(self, msg):
        """Enqueue *msg* for asynchronous delivery; never blocks on network."""
        self._queue.put(msg)

    def _run(self):
        """Daemon loop: batch queued messages and post them, rate-limited."""
        while True:
            try:
                msg = self._next_message()
                if msg is not None:
                    # Rate-limit only when there is something to send. The
                    # original waited-and-acted first, consuming a send slot
                    # even on an empty queue and delaying real messages.
                    self._rate_limit.wait_and_act()
                    self._ding_chat.send_text(msg)
            except Exception:
                # Narrowed from a bare `except:` so SystemExit /
                # KeyboardInterrupt stay fatal.
                catch_exception()
            time.sleep(1)

    def _next_message(self):
        """Merge up to 11 queued messages into one text.

        Returns None when the queue is empty. With a single message the text
        is keyword + message; with several, messages are joined by a
        '---------' separator line under the keyword.
        """
        msg = self.keyword

        msg_li = []
        while True:
            try:
                msg_i = self._queue.get(False)
                msg_li.append(str(msg_i))
            except Empty:
                break
            if len(msg_li) > 10:  # cap one outgoing text at 11 queued items
                break

        if len(msg_li) == 0:
            return None
        if len(msg_li) > 1:
            msg += '\n'
        msg += '\n---------\n'.join(msg_li)
        return msg


class _RateLimit:
    def __init__(self, name, times, seconds):
        self._name = name
        self._n_times = times
        self._n_seconds = seconds
        self.__action_record = []

        self.__lock = Lock()

    def __repr__(self):
        return f'RateLimit({self._name}, {self._n_times}, {self._n_seconds})'

    def __cut_record(self):
        while len(self.__action_record) > 0:
            this_interval = time.time() - self.__action_record[0]
            if this_interval > self._n_seconds:
                self.__action_record.pop(0)
            else:
                break

    @classmethod
    def create_with_des(cls, name, times_per_secs):
        times, seconds = times_per_secs.split('/')
        times, seconds = int(times), int(seconds)
        return cls(name, times, seconds)

    def act(self):
        assert self.times_left() > 0
        self.__action_record.append(time.time())
        self.__cut_record()

    def act_allow_exceed(self):
        self.__action_record.append(time.time())
        self.__cut_record()

    def times_left(self):
        self.__cut_record()
        return self._n_times - len(self.__action_record)

    def has_access(self):
        return self.times_left() > 0

    def retry_after(self):
        times_left = self.times_left()
        if times_left > 0:
            return 0
        time_0 = self.__action_record[0]
        pop_0_time = time_0 + self._n_seconds
        retry_after = pop_0_time - time.time()
        if retry_after < 0:
            return 0
        return retry_after

    def wait_and_act(self):
        # while self.times_left() <= 0:        --------------
        #     time.sleep(self.retry_after())          |      未加锁版本，导致多线程同时.act()
        # self.act()                           --------------

        lock_got = self.__lock.acquire(True, 100)
        if not lock_got:
            raise Exception('Lock not got for 100 seconds')

        while self.times_left() <= 0:
            time.sleep(self.retry_after())
        self.act()

        self.__lock.release()

    def assert_attr(self, times, seconds):
        assert self._n_times == times and self._n_seconds == seconds


# Module-level singleton of __NoneValue (defined earlier in this file) —
# presumably a sentinel meaning "no value" distinct from None; confirm at its definition.
NoneValue = __NoneValue()


if __name__ == '__main__':
    # Manual smoke-test harness: each try_* function below exercises one
    # component of this module interactively; only create_influx_db() is
    # actually invoked at the bottom.
    from quant.utils import set_test_mode
    from quant.const import *
    set_test_mode()

    def try_raw_logger():
        # Write one raw book record through MongoRawLogger, then pause
        # so the result can be inspected before the process exits.
        lo = MongoRawLogger('name1')
        lo.log_raw(Event.Book, Exchange.Binance, 'matic/usdt.swap', DataFrequency.Low, time.time(), 'raw')
        input(':')

    def try_profit_calculator():
        # Feed two buy/sell round-trips into a test account and print
        # the profit computed by ProfitCalculator.
        import keys
        from quant.accounts import Accounts
        from quant.accounts import Order

        acc = Accounts().create_account('Binance', keys.null_key)
        cal = ProfitCalculator(acc)
        o1 = Order('btc/usdt', 'buy', 10000, 111)
        acc.info_engine.put_deal(o1, 1)
        o1 = Order('btc/usdt', 'sell', 10001, 115)
        acc.info_engine.put_deal(o1, 1)
        o1 = Order('btc/usdt', 'buy', 10005, 111)
        acc.info_engine.put_deal(o1, 3)
        o1 = Order('btc/usdt', 'sell', 10004, 115)
        acc.info_engine.put_deal(o1, 3)
        print(cal.get_profit())

    def try_value_updater():
        # Interactive check of ValueUpdater: type a key, see the value it
        # returns as the underlying get_fn keeps changing its result.
        class DynamicGetFn:
            def __init__(self):
                self.count = 0
                self.count_2 = 100

            def get_fn(self):
                # Each call grows the dict by one key and shifts all values,
                # so refreshes are observable from the printed results.
                print('get()')
                self.count += 1
                self.count_2 += 100
                result = {str(i): self.count_2 + i for i in range(self.count)}
                return result

        updater = ValueUpdater('try', DynamicGetFn().get_fn, 1)

        while True:
            get = input('get:')
            a = updater.get_value(get)
            print(a)

    def try_rpc():
        # Round-trip a few calls through RpcServer/RpcClient on localhost:5555.
        def fn(a, b, c):
            print('call...')
            print('fn({}, {}, {})'.format(a, b, c))

        server = RpcServer()
        server.start()
        server.register('func', fn)

        client = RpcClient('localhost', 5555)
        client.start()

        a = client.call('func', 1, 2, c=3)
        print(a)

        a = client.call('func', 1, 2, c=3)
        print(a)

        a = client.call('func', 1, 2, c=3)
        print(a)

        a = client.call('func', 1, 2, c=3)
        print(a)

    def try_recent_collect():
        # Subscribe book + trade markets for doge on three exchanges and
        # run RecentRawCollector against them until interrupted.
        from quant.markets import Markets

        mds = [
            [Event.Book,  Exchange.Binance, 'doge/usdt.swap'],
            [Event.Trade, Exchange.Binance, 'doge/usdt.swap'],

            [Event.Book,  Exchange.Binance, 'doge/usdt'],
            [Event.Trade, Exchange.Binance, 'doge/usdt'],

            [Event.Book,  Exchange.Okex,    'doge/usdt.swap'],
            [Event.Trade, Exchange.Okex,    'doge/usdt.swap'],

            [Event.Book,  Exchange.Okex,    'doge/usdt'],
            [Event.Trade, Exchange.Okex,    'doge/usdt'],

            [Event.Book,  Exchange.Huobi,   'doge/usdt.swap'],
            [Event.Trade, Exchange.Huobi,   'doge/usdt.swap'],

            [Event.Book,  Exchange.Huobi,   'doge/usdt'],
            [Event.Trade, Exchange.Huobi,   'doge/usdt'],
        ]

        mar = Markets()
        for d in mds:
            mar.add_market(*d)

        RecentRawCollector('doge-15t', mar, FileRawLogger, 1, 2)

        while True:
            input(':')

    def try_ding():
        # Send one test message through the default DingTalk bot and wait.
        ding = Ding.create_default_ding()
        ding.send('dingdingding')
        input(':')

    def create_influx_db():
        # Create the 'attempts' database on the 'raw' server, then loop
        # printing storage stats on demand.
        import servers

        host = servers.get_ip('raw')
        influx = InfluxClient(host, 8086)

        influx.create_database('attempts')

        # for p in influx.client.get_list_retention_policies('modify_quick'):
        #     print(p)

        # print(influx.query('drop measurement switch', 'running_quick'))

        while True:
            influx.show_storage()
            input(':')

    create_influx_db()






















